{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Resumable backup and restore for very large indexes\n",
    "\n",
    "**This unofficial code sample is offered \"as-is\" and might not work for all customers and scenarios. If you run into difficulties, you should manually recreate and reload your search index on a new search service.**\n",
    "\n",
    "If your indexes contain more than 100,000 documents, use this sample code to move them to a new search service. The simpler [backup and restore code](https://github.com/Azure/azure-search-vector-samples/tree/main/demo-python/code/utilities/index-backup-restore) assembles the backup by paging through documents, an approach that reaches only the first 100,000 results of a query. This code instead sorts documents by a timestamp field and uses filters on that timestamp to move documents from one index to another in batches.\n",
    "\n",
    "> **Note**: Azure AI Search now supports [service upgrades](https://learn.microsoft.com/azure/search/search-how-to-upgrade) and [pricing tier changes](https://learn.microsoft.com/azure/search/search-capacity-planning#change-your-pricing-tier). If you're backing up and restoring your index for migration to a higher capacity service, you now have other options.\n",
    "\n",
    "This code requires a [timestamp field](https://learn.microsoft.com/rest/api/searchservice/supported-data-types#edm-data-types-for-nonvector-fields) that records when a document was created or last updated. The field must be [filterable](https://learn.microsoft.com/azure/search/search-filters) and [sortable](https://learn.microsoft.com/azure/search/query-odata-filter-orderby-syntax). If you update this timestamp every time you update a document in your index, the field is a built-in record of when each document last changed, and you can use it to implement a resumable backup and restore. The most recently backed up timestamp can be recorded, so a backup can pause at that timestamp and resume later.\n",
    "\n",
    "You can also speed up the backup by running parallel backup jobs: set the number of partitions (`desired_partitions`) and the number of backup tasks (`backup_tasks`) to values greater than 1. When you use parallel backup jobs, consider the following limitations:\n",
    "\n",
    "* If documents are added to the index or existing documents are modified during the backup, the new or modified documents aren't included in the backup, because their timestamps are later than the timestamps recorded when the backup started.\n",
    "\n",
    "* Deletes made during the backup might not be propagated to the backup copy of the index. Avoid deleting documents while a backup is running.\n",
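    "\n",
    "The timestamp paging described above can be modeled without an Azure dependency. The following is a minimal, hypothetical sketch, not the notebook's implementation. Documents live in a Python list. Each page is the next `page_size` documents sorted by timestamp and filtered with `ge` on the last timestamp copied, and that timestamp is the checkpoint a later run could resume from. Because the filter uses `ge`, documents that share the boundary timestamp are copied twice, which is harmless when writes are keyed by document key. The helpers `build_filter`, `fetch_page`, and `resumable_copy` are illustrative names only.\n",
    "\n",
    "```python\n",
    "from typing import Dict, List, Optional, Tuple\n",
    "\n",
    "def build_filter(field: str, start: Optional[str], end: Optional[str] = None) -> Optional[str]:\n",
    "    # Same shape as the OData filters in this notebook: ge for the lower bound, le for the upper bound\n",
    "    if start and end:\n",
    "        return f'{field} ge {start} and {field} le {end}'\n",
    "    if start:\n",
    "        return f'{field} ge {start}'\n",
    "    return None\n",
    "\n",
    "def fetch_page(docs: List[Dict], field: str, start: Optional[str], size: int) -> List[Dict]:\n",
    "    # Hypothetical in-memory stand-in for a query sorted by timestamp asc and filtered with ge start\n",
    "    ordered = sorted(docs, key=lambda d: d[field])\n",
    "    return [d for d in ordered if start is None or d[field] >= start][:size]\n",
    "\n",
    "def resumable_copy(docs: List[Dict], field: str, page_size: int, checkpoint: Optional[str] = None) -> Tuple[Dict, Optional[str]]:\n",
    "    copied: Dict = {}\n",
    "    cursor = checkpoint\n",
    "    while True:\n",
    "        page = fetch_page(docs, field, cursor, page_size)\n",
    "        for d in page:\n",
    "            copied[d['id']] = d  # Keyed writes make re-copying boundary documents harmless\n",
    "        if len(page) < page_size:\n",
    "            return copied, cursor  # A short page means no documents remain\n",
    "        next_cursor = page[-1][field]  # Checkpoint: a later run can resume from here\n",
    "        if next_cursor == cursor:\n",
    "            raise RuntimeError('A full page shares one timestamp; increase page_size')\n",
    "        cursor = next_cursor\n",
    "\n",
    "docs = [{'id': i, 'ts': f'2024-01-01T00:00:0{i}.000Z'} for i in range(10)]\n",
    "copied, last = resumable_copy(docs, 'ts', page_size=4)\n",
    "print(sorted(copied) == list(range(10)), last)  # True 2024-01-01T00:00:09.000Z\n",
    "```\n",
    "\n",
    "The `RuntimeError` guard covers a case that `ge` paging can't handle on its own: if a full page contains only documents at the cursor timestamp, the cursor can't advance.\n",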
    "\n",
    "## Install packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install -r requirements.txt --quiet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load environment variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dotenv import load_dotenv\n",
    "from azure.identity.aio import DefaultAzureCredential\n",
    "from azure.core.credentials import AzureKeyCredential\n",
    "import os\n",
    "\n",
    "# Copy sample.env to .env and change the variables for your service\n",
    "load_dotenv(override=True)\n",
    "\n",
    "# The sample.env file contains more variables than this code needs. Ignore any variables not used here.\n",
    "# Provide a search service containing the source index for the backup operation\n",
    "source_endpoint = os.environ[\"AZURE_SEARCH_SOURCE_SERVICE_ENDPOINT\"]\n",
    "# Provide an admin API key if you're using key-based authentication. Using a key is optional. See https://learn.microsoft.com/azure/search/keyless-connections\n",
    "source_credential = AzureKeyCredential(os.getenv(\"AZURE_SEARCH_SOURCE_ADMIN_KEY\")) if os.getenv(\"AZURE_SEARCH_SOURCE_ADMIN_KEY\") else DefaultAzureCredential()\n",
    "# Provide a second search service as the destination for the new restored index\n",
    "destination_endpoint = os.environ[\"AZURE_SEARCH_DESTINATION_SERVICE_ENDPOINT\"]\n",
    "destination_credential = AzureKeyCredential(os.getenv(\"AZURE_SEARCH_DESTINATION_ADMIN_KEY\")) if os.getenv(\"AZURE_SEARCH_DESTINATION_ADMIN_KEY\") else DefaultAzureCredential()\n",
    "# Name of the index to be backed up\n",
    "index_name = os.getenv(\"AZURE_SEARCH_INDEX\", \"\")\n",
    "# Optionally, specify multiple indexes as a comma-separated list. If not specified, the value of AZURE_SEARCH_INDEX is used.\n",
    "index_names = [name.strip() for name in os.environ[\"AZURE_SEARCH_INDEXES\"].split(\",\")] if \"AZURE_SEARCH_INDEXES\" in os.environ else [index_name]\n",
    "# Name of the timestamp field \n",
    "timestamp_field_name = os.environ[\"AZURE_SEARCH_TIMESTAMP_FIELD\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from azure.search.documents.indexes.aio import SearchIndexClient\n",
    "from azure.search.documents.aio import SearchClient\n",
    "from azure.search.documents.indexes.models import BinaryQuantizationCompression, SearchField\n",
    "from datetime import datetime, timedelta\n",
    "from uuid import uuid4\n",
    "import random\n",
    "\n",
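    "# Set to True to add binary quantization compression to destination vector search profiles that don't already use compression\n",
    "enable_compression = False\n",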
    "\n",
    "# Copies an index definition from the source service to the destination\n",
    "async def copy_index_definition(source_index_client: SearchIndexClient, destination_index_client: SearchIndexClient, index_name: str):\n",
    "    index = await source_index_client.get_index(index_name)\n",
    "    # Check for any synonym maps\n",
    "    synonym_map_names = []\n",
    "    for field in index.fields:\n",
    "        if field.synonym_map_names:\n",
    "            synonym_map_names.extend(field.synonym_map_names)\n",
    "    \n",
    "    # Copy over synonym maps if they exist\n",
    "    for synonym_map_name in synonym_map_names:\n",
    "        synonym_map = await source_index_client.get_synonym_map(synonym_map_name)\n",
    "        await destination_index_client.create_or_update_synonym_map(synonym_map)\n",
    "\n",
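    "    # Only indexes with a vector search configuration can use compression\n",
    "    if enable_compression and index.vector_search:\n",
    "        if index.vector_search.compressions is None:\n",
    "            index.vector_search.compressions = []\n",
    "        for profile in index.vector_search.profiles or []:\n",
    "            if not profile.compression_name:\n",
    "                profile.compression_name = \"mycompression\"\n",
    "\n",
    "        index.vector_search.compressions.append(\n",
    "            BinaryQuantizationCompression(\n",
    "                compression_name=\"mycompression\",\n",
    "                rerank_with_original_vectors=True,\n",
    "                default_oversampling=10\n",
    "            ))\n",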
    "    \n",
    "    # Copy over the index\n",
    "    await destination_index_client.create_or_update_index(index)\n",
    "\n",
    "# Method to convert a datetime to a timestamp string\n",
    "def datetime_to_timestamp(date: datetime) -> str:\n",
    "    # Trim microseconds to milliseconds. Timestamp precision is to milliseconds only. See https://learn.microsoft.com/rest/api/searchservice/supported-data-types#edm-data-types-for-nonvector-fields for more information\n",
    "    return date.strftime(\"%Y-%m-%dT%H:%M:%S.%f\")[:-3] + \"Z\"\n",
    "\n",
    "def get_random_timestamp(start_time: datetime, end_time: datetime) -> str:\n",
    "    delta = end_time - start_time\n",
    "    random_seconds = random.randint(0, int(delta.total_seconds()))\n",
    "    return datetime_to_timestamp(start_time + timedelta(seconds=random_seconds))\n",
    "\n",
    "# Add a timestamp field to the index\n",
    "async def add_timestamp_to_index(source_index_client: SearchIndexClient, source_client: SearchClient, index_name: str, timestamp_field_name: str, start_timestamp: datetime, end_timestamp: datetime):\n",
    "    index = await source_index_client.get_index(index_name)\n",
    "    timestamp_field_added = False\n",
    "    key_field = None\n",
    "    for field in index.fields:\n",
    "        if not key_field and field.key:\n",
    "            key_field = field\n",
    "        if field.name == timestamp_field_name:\n",
    "            timestamp_field_added = True\n",
    "\n",
    "    if not timestamp_field_added:\n",
    "        index.fields.append(SearchField(name=timestamp_field_name, type=\"Edm.DateTimeOffset\", facetable=False, filterable=True, sortable=True, hidden=False))\n",
    "\n",
    "    await source_index_client.create_or_update_index(index)\n",
    "\n",
    "    # Create a session when paging through results to ensure consistency in multi-replica services\n",
    "    # For more information, please see https://learn.microsoft.com/azure/search/index-similarity-and-scoring#scoring-statistics-and-sticky-sessions\n",
    "    session_id = str(uuid4())\n",
    "    get_next_results = True\n",
    "    while get_next_results:\n",
    "        total_results_size = 0\n",
    "        filter = f\"{timestamp_field_name} eq null\"\n",
    "        results = await source_client.search(\n",
    "            search_text=\"*\",\n",
    "            top=100000,\n",
    "            filter=filter,\n",
    "            session_id=session_id,\n",
    "            select=[key_field.name]\n",
    "        )\n",
    "\n",
    "        results_by_page = results.by_page()\n",
    "        async for page in results_by_page:\n",
    "            # Add a timestamp to this page of results\n",
    "            update_page = [{ key_field.name: item[key_field.name], timestamp_field_name: get_random_timestamp(start_timestamp, end_timestamp) } async for item in page]\n",
    "            if len(update_page) > 0:\n",
    "                await source_client.merge_documents(update_page)\n",
    "            total_results_size += len(update_page)\n",
    "        \n",
    "        # If any results were returned, more documents without a timestamp may remain\n",
    "        # Continue the search\n",
    "        get_next_results = total_results_size > 0\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
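    "## (Optional) Add a timestamp field\n",
    "\n",
    "If your index doesn't have a timestamp field to use for resuming, you can add one and fill it with generated timestamps. Try to distribute these timestamps evenly across your index: parallel backup jobs split the index into partitions by timestamp range, so uneven timestamps produce unevenly sized partitions. The next cell adds a filterable, sortable `Edm.DateTimeOffset` field and assigns each document that has no timestamp a random time, at one-second granularity, within the current day."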
   ]
  },
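  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The next cell assigns each document a random timestamp within the current day. Random values at one-second granularity start repeating once an index holds many documents, and a partition boundary can't split documents that share a timestamp. One alternative, shown here as a hypothetical sketch that this notebook doesn't use, is to space timestamps at equal intervals across a time range. The function names `to_timestamp` and `evenly_spaced_timestamps` are illustrative. The output uses millisecond precision with a trailing `Z`.\n",
    "\n",
    "```python\n",
    "from datetime import datetime\n",
    "from typing import List\n",
    "\n",
    "def to_timestamp(date: datetime) -> str:\n",
    "    # Keep millisecond precision and append Z (UTC designator)\n",
    "    return date.strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z'\n",
    "\n",
    "def evenly_spaced_timestamps(start: datetime, end: datetime, count: int) -> List[str]:\n",
    "    # Spread count timestamps across [start, end] at equal intervals\n",
    "    if count <= 0:\n",
    "        return []\n",
    "    step = (end - start) / max(count - 1, 1)\n",
    "    return [to_timestamp(start + step * i) for i in range(count)]\n",
    "\n",
    "stamps = evenly_spaced_timestamps(datetime(2024, 1, 1), datetime(2024, 1, 1, 0, 0, 4), 5)\n",
    "print(stamps[0], stamps[-1])  # 2024-01-01T00:00:00.000Z 2024-01-01T00:00:04.000Z\n",
    "```\n",
    "\n",
    "To apply this idea to an existing index, you would pair each generated value with a document key while paging through the index, instead of calling `get_random_timestamp`."
   ]
  },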
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datetime import datetime, time\n",
    "\n",
    "for index_name in index_names:\n",
    "    async with SearchIndexClient(endpoint=source_endpoint, credential=source_credential) as source_index_client, SearchClient(endpoint=source_endpoint, credential=source_credential, index_name=index_name) as source_client:\n",
    "        now = datetime.now()\n",
    "        start_of_day = datetime.combine(now.date(), time.min)\n",
    "        end_of_day = datetime.combine(now.date(), time.max)\n",
    "\n",
    "        await add_timestamp_to_index(source_index_client, source_client, index_name, timestamp_field_name, start_timestamp=start_of_day, end_timestamp=end_of_day)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Copy index definition\n",
    "Copy the source index definition to the destination service."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "source_index_client = SearchIndexClient(endpoint=source_endpoint, credential=source_credential)\n",
    "destination_index_client = SearchIndexClient(endpoint=destination_endpoint, credential=destination_credential)\n",
    "\n",
    "for index_name in index_names:\n",
    "    await copy_index_definition(source_index_client, destination_index_client, index_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from azure.search.documents.indexes.aio import SearchIndexClient\n",
    "from azure.search.documents.indexes.models import SearchFieldDataType\n",
    "from typing import List\n",
    "\n",
    "# Method to validate the timestamp field exists, is filterable, and is sortable\n",
    "async def validate_resume_backup_and_restore(index_client: SearchIndexClient, index_name: str, timestamp_field_name: str) -> bool:\n",
    "    index = await index_client.get_index(index_name)\n",
    "\n",
    "    found_field = False\n",
    "    for field in index.fields:\n",
    "        if field.name == timestamp_field_name:\n",
    "            found_field = True\n",
    "            if field.type != SearchFieldDataType.DateTimeOffset:\n",
    "                # Field must be a timestamp\n",
    "                return False\n",
    "            if not field.filterable:\n",
    "                # Field must be filterable\n",
    "                return False\n",
    "            if not field.sortable:\n",
    "                # Field must be sortable\n",
    "                return False\n",
    "            break\n",
    "    \n",
    "    # Field must exist on the index\n",
    "    return found_field\n",
    "\n",
    "# Method to validate which fields can and cannot be backed up\n",
    "async def validate_fields_backup_and_restore(index_client: SearchIndexClient, index_name: str) -> List[str]:\n",
    "    missing_fields = []\n",
    "    index = await index_client.get_index(index_name)\n",
    "    for field in index.fields:\n",
    "        message = \"\"\n",
    "        # Complex fields have no stored attribute, so skip the stored check for them\n",
    "        if not field.stored and not field.fields:\n",
    "            message += f\"Field {field.name} cannot be backed up because it's not marked as stored\\n\"\n",
    "        elif field.hidden: \n",
    "            message += f\"Field {field.name} cannot be backed up because it's not marked as retrievable\\n\"\n",
    "        \n",
    "        if message:\n",
    "            missing_fields.append(message)\n",
    "    \n",
    "    return missing_fields\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Validate backup and restore\n",
    "\n",
    "* Make sure the timestamp field exists, is of type `Edm.DateTimeOffset`, and is filterable and sortable.\n",
    "* If a field is not marked as [stored](https://learn.microsoft.com/azure/search/vector-search-how-to-storage-options), it cannot be backed up.\n",
    "* If a field is not marked as [retrievable](https://learn.microsoft.com/azure/search/search-pagination-page-layout#result-composition), it won't be backed up.\n",
    "  * If the field is marked as stored, you can make it retrievable by updating the index definition.\n",
    "  * If the field isn't marked as stored, it can't be made retrievable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for index_name in index_names:\n",
    "    can_resume_backup_and_restore = await validate_resume_backup_and_restore(source_index_client, index_name, timestamp_field_name)\n",
    "    if can_resume_backup_and_restore:\n",
    "        print(f\"Index {index_name} has a valid timestamp field and can use resumable backup and restore\")\n",
    "    else:\n",
    "        print(f\"Index {index_name} does not have a valid timestamp field and cannot use resumable backup and restore\")\n",
    "\n",
    "    missing_fields_messages = await validate_fields_backup_and_restore(source_index_client, index_name)\n",
    "    for message in missing_fields_messages:\n",
    "        print(message)"
   ]
  },
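  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The next cell defines the backup helpers. To support parallel backup jobs, `get_partition_bounds` splits the index into timestamp ranges. It runs count queries inside a modified binary search to find split points that give each partition roughly the same number of documents. The hypothetical sketch below illustrates the idea on an in-memory sorted list of millisecond values, with `bisect` standing in for the count query. It's a simplified, standard lower-bound binary search with no size tolerance, not the notebook's exact algorithm. The function names `count_between` and `partition_bounds` are illustrative.\n",
    "\n",
    "```python\n",
    "from bisect import bisect_left, bisect_right\n",
    "from typing import List\n",
    "\n",
    "def count_between(ordered: List[int], low: int, high: int) -> int:\n",
    "    # Stand-in for a count query filtered with ge low and le high\n",
    "    return max(0, bisect_right(ordered, high) - bisect_left(ordered, low))\n",
    "\n",
    "def partition_bounds(ordered: List[int], partitions: int) -> List[int]:\n",
    "    # Returns the inclusive upper bound of every partition except the last\n",
    "    bounds: List[int] = []\n",
    "    low, last = ordered[0], ordered[-1]\n",
    "    for p in range(partitions - 1):\n",
    "        # Target: unassigned documents divided by the partitions still to create\n",
    "        target = count_between(ordered, low, last) // (partitions - p)\n",
    "        lo, hi = low, last\n",
    "        # Find the smallest upper bound whose partition holds at least target documents\n",
    "        while lo < hi:\n",
    "            mid = (lo + hi) // 2\n",
    "            if count_between(ordered, low, mid) < target:\n",
    "                lo = mid + 1\n",
    "            else:\n",
    "                hi = mid\n",
    "        bounds.append(lo)\n",
    "        low = lo + 1  # The next partition starts 1 ms later, as in get_partitions\n",
    "    return bounds\n",
    "\n",
    "print(partition_bounds(list(range(100)), 4))  # [24, 49, 74]\n",
    "```"
   ]
  },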
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from azure.search.documents.aio import SearchClient\n",
    "from typing import Optional, AsyncGenerator, List, Callable, Tuple\n",
    "from tqdm.notebook import tqdm\n",
    "import ipywidgets as widgets\n",
    "from IPython.display import display\n",
    "from uuid import uuid4\n",
    "import asyncio\n",
    "from datetime import datetime, timedelta\n",
    "from dataclasses import dataclass, asdict\n",
    "from copy import deepcopy\n",
    "import os\n",
    "import json\n",
    "import re\n",
    "\n",
    "# Class representing a partition, a subset of an index that can be used to create parallel backup jobs\n",
    "@dataclass\n",
    "class Partition:\n",
    "    id: int\n",
    "    start: str\n",
    "    end: str\n",
    "    last: Optional[str]\n",
    "\n",
    "# Method to check how many documents are remaining in an index. The check can be scoped down to a single part of an index by timestamp\n",
    "async def get_total_documents_remaining(client: SearchClient, timestamp_field_name: str, min_timestamp: Optional[str] = None, max_timestamp: Optional[str] = None) -> int:\n",
    "    filter = None\n",
    "    if min_timestamp and not max_timestamp:\n",
    "        # If only a minimum timestamp is specified, check all documents greater than or equal to this timestamp\n",
    "        filter = f\"{timestamp_field_name} ge {min_timestamp}\"\n",
    "    elif min_timestamp and max_timestamp:\n",
    "        # If minimum and maximum timestamps are specified, check all documents between these timestamps\n",
    "        filter = f\"{timestamp_field_name} ge {min_timestamp} and {timestamp_field_name} le {max_timestamp}\"\n",
    "    results = await client.search(\n",
    "        search_text=\"*\",\n",
    "        include_total_count=True,\n",
    "        filter=filter,\n",
    "        top=0\n",
    "    )\n",
    "    return await results.get_count()\n",
    "\n",
    "# Method to find either the minimum or maximum timestamp in an index\n",
    "async def get_timestamp_bound(client: SearchClient, timestamp_field_name: str, max: bool) -> Optional[str]:\n",
    "    result = await client.search(\n",
    "        search_text=\"*\",\n",
    "        order_by=f\"{timestamp_field_name} {'desc' if max else 'asc'}\",\n",
    "        top=1,\n",
    "        select=[timestamp_field_name]\n",
    "    )\n",
    "    result = [item async for item in result]\n",
    "    if len(result) == 0:\n",
    "        return None\n",
    "    return result[0][timestamp_field_name]\n",
    "\n",
    "# Methods to convert a timestamp to and from datetime\n",
    "def timestamp_to_datetime(timestamp: str) -> datetime:\n",
    "    try:\n",
    "        return datetime.strptime(timestamp, \"%Y-%m-%dT%H:%M:%S.%fZ\")\n",
    "    except ValueError:\n",
    "        return datetime.strptime(timestamp, \"%Y-%m-%dT%H:%M:%SZ\")\n",
    "\n",
    "def datetime_to_timestamp(date: datetime) -> str:\n",
    "    # Trim microseconds to milliseconds. Timestamp precision is to milliseconds only. See https://learn.microsoft.com/rest/api/searchservice/supported-data-types#edm-data-types-for-nonvector-fields for more information\n",
    "    return date.strftime(\"%Y-%m-%dT%H:%M:%S.%f\")[:-3] + \"Z\"\n",
    "\n",
    "# Method to get bounds of partitions for parallel backup jobs.\n",
    "# Set desired_partitions to 1 to disable parallel backup jobs\n",
    "async def get_partition_bounds(client: SearchClient, timestamp_field_name: str, desired_partitions: int = 2, partition_size_threshold: float = 0.05, min_timestamp: Optional[str] = None, max_timestamp: Optional[str] = None) -> List[datetime]:\n",
    "    # Determine the minimum and maximum timestamps to back up. Default to taking them from the index\n",
    "    if max_timestamp is None:\n",
    "        max_timestamp = await get_timestamp_bound(client, timestamp_field_name, max=True)\n",
    "        if max_timestamp is None:\n",
    "            return []\n",
    "    if min_timestamp is None:\n",
    "        min_timestamp = await get_timestamp_bound(client, timestamp_field_name, max=False)\n",
    "\n",
    "    # If there's only 1 timestamp or parallel backup jobs are disabled, do not partition\n",
    "    if min_timestamp == max_timestamp or desired_partitions == 1:\n",
    "        return []\n",
    "\n",
    "    # Attempt to divide the index into roughly equally sized partitions\n",
    "    # Partitions are not guaranteed to be of a similar size. The timestamp distribution of data in your index affects the size of each partition\n",
    "    partition_splits = []\n",
    "    low = timestamp_to_datetime(min_timestamp)\n",
    "    for partition in range(desired_partitions - 1):\n",
    "        high = timestamp_to_datetime(max_timestamp)\n",
    "        remaining_partitions = desired_partitions - partition\n",
    "        # Determine the target partition size: the number of documents not yet assigned to a partition divided by the number of partitions left to create\n",
    "        # Partition sizes rarely match the target exactly, so partition_size_threshold sets how far, as a fraction of the target size, a partition may deviate from it\n",
    "        # For example, with the default threshold of 0.05, a partition within 5% of the target size is accepted\n",
    "        # Unevenly sized partitions may slow down the parallel backup jobs\n",
    "        target_partition_size = await get_total_documents_remaining(client, timestamp_field_name, min_timestamp=datetime_to_timestamp(low)) // remaining_partitions\n",
    "        partition_threshold = target_partition_size * partition_size_threshold\n",
    "        # If an optimal partition size cannot be picked, track all the potential partition sizes to pick the best one\n",
    "        partition_sizes = []\n",
    "\n",
    "        # Perform a modified binary search to determine the bounds of the partition\n",
    "        best_split = None\n",
    "        mid = low + (high - low) / 2\n",
    "        while low <= high:\n",
    "            current_partition_size = await get_total_documents_remaining(client, timestamp_field_name, datetime_to_timestamp(low), datetime_to_timestamp(mid))\n",
    "            partition_sizes.append((mid, current_partition_size))\n",
    "            # Check if the partition is an acceptable size. If it's not, continue the binary search\n",
    "            if current_partition_size < target_partition_size + partition_threshold and current_partition_size > target_partition_size - partition_threshold:\n",
    "                best_split = mid\n",
    "                break\n",
    "            elif current_partition_size < target_partition_size:\n",
    "                mid = mid + (high - mid) / 2\n",
    "            else:\n",
    "                prev_high = high\n",
    "                high = mid\n",
    "                mid = mid - (mid - low) / 2\n",
    "                if prev_high == high:\n",
    "                    # No progress being made\n",
    "                    best_split = None\n",
    "                    break\n",
    "        \n",
    "        # If an acceptable partition could not be found, pick the one that has the closest size\n",
    "        if best_split is None:\n",
    "            min_difference = -1\n",
    "            for split, partition_size in partition_sizes:\n",
    "                difference = abs(target_partition_size - partition_size)\n",
    "                if min_difference == -1 or difference < min_difference:\n",
    "                    best_split = split\n",
    "                    min_difference = difference\n",
    "\n",
    "        if best_split:\n",
    "            partition_splits.append(best_split)\n",
    "            low = best_split + timedelta(milliseconds=1)\n",
    "        else:\n",
    "            # Cannot partition anymore, exit\n",
    "            partition_splits.append(low)\n",
    "            break\n",
    "\n",
    "    return partition_splits\n",
    "\n",
    "# Method to create partitions for parallel backup jobs\n",
    "# Requires using the bounds from the previous method\n",
    "async def get_partitions(client: SearchClient, timestamp_field_name: str, partition_splits: List[datetime], start_id: int = 0, min_timestamp: Optional[str] = None, max_timestamp: Optional[str] = None) -> List[Partition]:\n",
    "    # The minimum and maximum timestamps in the source index are part of the partition bounds\n",
    "    if max_timestamp is None:\n",
    "        max_timestamp = await get_timestamp_bound(client, timestamp_field_name, max=True)\n",
    "        if max_timestamp is None:\n",
    "            return []\n",
    "    if min_timestamp is None:\n",
    "        min_timestamp = await get_timestamp_bound(client, timestamp_field_name, max=False)\n",
    "\n",
    "    # Create a new partition for every pair of bounds\n",
    "    prev_partition_end = timestamp_to_datetime(min_timestamp)\n",
    "    partitions = []\n",
    "    for i, partition_split in enumerate(partition_splits):\n",
    "        partitions.append(Partition(id=start_id + i, start=datetime_to_timestamp(prev_partition_end), end=datetime_to_timestamp(partition_split), last=None))\n",
    "        # The next partition starts 1 millisecond after the previous one to avoid overlap\n",
    "        prev_partition_end = partition_split + timedelta(milliseconds=1)\n",
    "    partitions.append(Partition(id=start_id + len(partition_splits), start=datetime_to_timestamp(prev_partition_end), end=max_timestamp, last=None))\n",
    "    return partitions\n",
    "\n",
    "# Resume fetching search results from a source index for backup.\n",
    "# May have timestamp bounds if resuming from a previous backup job or using parallel backup jobs\n",
    "async def resume_backup_results(client: SearchClient, timestamp_field_name: str, timestamp: Optional[str], max_timestamp: Optional[str] = None, select=None) -> AsyncGenerator[List[dict], None]:\n",
    "    # Create a session when paging through results to ensure consistency in multi-replica services\n",
    "    # For more information, please see https://learn.microsoft.com/azure/search/index-similarity-and-scoring#scoring-statistics-and-sticky-sessions\n",
    "    session_id = str(uuid4())\n",
    "    # A single search query returns at most 100,000 results. To go past this limit, the query is sorted by timestamp and repeated with a filter that starts at the last timestamp seen\n",
    "    # For more information, please see https://learn.microsoft.com/azure/search/search-pagination-page-layout#paging-through-a-large-number-of-results\n",
    "    max_results_size = 100000\n",
    "    get_next_results = True\n",
    "    while get_next_results:\n",
    "        total_results_size = 0\n",
    "        filter = None\n",
    "        if timestamp and not max_timestamp:\n",
    "            # If only a minimum timestamp is specified, find all records greater than or equal to it\n",
    "            filter = f\"{timestamp_field_name} ge {timestamp}\"\n",
    "        elif timestamp and max_timestamp:\n",
    "            # If using a minimum and maximum timestamp, find all records between them\n",
    "            filter = f\"{timestamp_field_name} ge {timestamp} and {timestamp_field_name} le {max_timestamp}\"\n",
    "        results = await client.search(\n",
    "            search_text=\"*\",\n",
    "            order_by=f\"{timestamp_field_name} asc\",\n",
    "            top=max_results_size,\n",
    "            filter=filter,\n",
    "            session_id=session_id,\n",
    "            select=select\n",
    "        )\n",
    "        results_by_page = results.by_page()\n",
    "\n",
    "        async for page in results_by_page:\n",
    "            next_page = [item async for item in page]\n",
    "            # Count how many results are returned\n",
    "            total_results_size += len(next_page)\n",
    "            if len(next_page) == 0:\n",
    "                break\n",
    "            yield next_page\n",
    "            timestamp = next_page[-1][timestamp_field_name]\n",
    "        \n",
    "        # If the maximum number of results was returned, more results may exist after the last timestamp searched\n",
    "        # Continue the search from the most recent timestamp\n",
    "        get_next_results = total_results_size == max_results_size\n",
    "\n",
    "# Method to initiate a backup of a search service\n",
    "# The number of partitions (whether to use parallel backup jobs) and the number of parallel backup uploads are configurable\n",
    "# The strategy used to save partition state is configurable using on_backup_page\n",
    "async def backup_index_with_resume(client: SearchClient, destination_client: SearchClient, timestamp_field_name: str, partitions: List[Partition], backup_tasks:int = 2, on_backup_page: Optional[Callable[[Partition], None]] = None) -> None:\n",
    "    total_documents = 0\n",
    "    total_partitions = len(partitions)\n",
    "    for partition in partitions:\n",
    "        total_documents += await get_total_documents_remaining(client, timestamp_field_name, partition.last or partition.start, partition.end)\n",
    "    if total_documents == 0:\n",
    "        return\n",
    "    \n",
    "    # Create a progress bar to visualize backup progress\n",
    "    # Create a label to track how many result pages are waiting for backup\n",
    "    progress_bar = tqdm(total=total_documents, desc=\"Backing up documents...\", unit=\"docs\", unit_scale=False)\n",
    "    pages_label = widgets.Label(value=\"Queued Result Pages: 0\")\n",
    "    display(pages_label)\n",
    "    \n",
    "    # Method to fetch all the search results for a backup job and queue them for backup\n",
    "    async def get_results(partition: Partition, results_queue: asyncio.Queue):\n",
    "        try:\n",
    "            results = resume_backup_results(client, timestamp_field_name, timestamp=partition.last or partition.start, max_timestamp=partition.end)\n",
    "            async for result_page in results:\n",
    "                await results_queue.put((partition, result_page))\n",
    "            await results_queue.put((partition, None))\n",
    "        except asyncio.CancelledError:\n",
    "            raise\n",
    "    \n",
    "    # Track how many parallel backup jobs have finished\n",
    "    finished_partitions = 0\n",
    "    finished_partitions_lock = asyncio.Lock()\n",
    "\n",
    "    # Track backup job tasks\n",
    "    backup_task_list = []\n",
    "\n",
    "    # Method to upload documents, splitting batches that are too large to upload in a single request\n",
    "    async def upload_documents(client: SearchClient, documents: List[dict], initial_batch_size: Optional[int] = None) -> int:\n",
    "        stack = [documents]\n",
    "        if initial_batch_size and initial_batch_size < len(documents):\n",
    "            stack = [\n",
    "                documents[i : i + initial_batch_size]\n",
    "                for i in range(0, len(documents), initial_batch_size)\n",
    "            ]\n",
    "        # record largest successful size\n",
    "        max_success = 0\n",
    "\n",
    "        while stack:\n",
    "            batch = stack.pop()\n",
    "            try:\n",
    "                await client.upload_documents(batch)\n",
    "                max_success = max(max_success, len(batch))\n",
    "            except TypeError as e: # Exception raised when a document is too large\n",
    "                if len(batch) == 1:\n",
    "                    # Single document too large\n",
    "                    raise\n",
    "                mid = len(batch) // 2\n",
    "                # Push halves onto stack to retry\n",
    "                stack.append(batch[mid:])\n",
    "                stack.append(batch[:mid])\n",
    "        \n",
    "        return max_success\n",
    "\n",
    "    # Method to fetch search results from a backup queue and back them up\n",
    "    async def backup_results(results_queue: asyncio.Queue, partition_update_queue: asyncio.Queue):\n",
    "        nonlocal finished_partitions\n",
    "        batch_size: Optional[int] = None\n",
    "        try:\n",
    "            while True:\n",
    "                partition, result_page = await results_queue.get()\n",
    "                if partition is None:\n",
    "                    # Exit\n",
    "                    break\n",
    "\n",
    "                if result_page is None:\n",
    "                    # The backup job completed. If all backup jobs have completed, exit\n",
    "                    async with finished_partitions_lock:\n",
    "                        finished_partitions += 1\n",
    "                        if finished_partitions >= total_partitions:\n",
    "                            # Signal the checkpoint task to exit\n",
    "                            await partition_update_queue.put(None)\n",
    "                            # Send one sentinel per backup worker so every worker exits\n",
    "                            for _ in backup_task_list:\n",
    "                                await results_queue.put((None, None))\n",
    "                            progress_bar.n = total_documents\n",
    "                            progress_bar.refresh()\n",
    "                    break\n",
    "                \n",
    "                # Update the partition state with the most recently completed backup\n",
    "                saved_timestamp = result_page[-1][timestamp_field_name]\n",
    "                partition.last = saved_timestamp\n",
    "\n",
    "                # Back up the search results and queue an update to the partition\n",
    "                batch_size = await upload_documents(destination_client, result_page, batch_size)\n",
    "                await partition_update_queue.put(deepcopy(partition))\n",
    "                if progress_bar.n < progress_bar.total:\n",
    "                    progress_bar.update(len(result_page))\n",
    "        except asyncio.CancelledError:\n",
    "            raise\n",
    "    \n",
    "    # Helper method to save a partition's state if it's been updated\n",
    "    async def checkpoint_results(partition_update_queue: asyncio.Queue, output_queue: asyncio.Queue):\n",
    "        partition_max_timestamps = {}\n",
    "        try:\n",
    "            while True:\n",
    "                partition = await partition_update_queue.get()\n",
    "                if partition is None:\n",
    "                    # No more updates, all backup jobs finished\n",
    "                    break\n",
    "                pages_label.value = f\"Queued Result Pages: {output_queue.qsize()}\"\n",
    "\n",
    "                # Only update this partition if this is the most recently processed update to the partition\n",
    "                max_timestamp = partition_max_timestamps.get(partition.id)\n",
    "                last_timestamp = timestamp_to_datetime(partition.last)\n",
    "                if not max_timestamp or last_timestamp >= max_timestamp:\n",
    "                    partition_max_timestamps[partition.id] = last_timestamp\n",
    "                    on_backup_page(partition)\n",
    "        except asyncio.CancelledError:\n",
    "            raise\n",
    "\n",
    "    results_queue = asyncio.Queue()\n",
    "    partition_update_queue = asyncio.Queue()\n",
    "\n",
    "    # Run the producers (one per partition), the backup workers, and the checkpoint task concurrently\n",
    "    result_task_list = [asyncio.create_task(get_results(partition, results_queue)) for partition in partitions]\n",
    "    backup_task_list.extend([asyncio.create_task(backup_results(results_queue, partition_update_queue)) for _ in range(backup_tasks)])\n",
    "    checkpoint_task = asyncio.create_task(checkpoint_results(partition_update_queue, results_queue))\n",
    "\n",
    "    # Wait for all tasks to complete\n",
    "    try:\n",
    "        await asyncio.gather(*result_task_list)\n",
    "        await asyncio.gather(*backup_task_list)\n",
    "        await checkpoint_task\n",
    "    except asyncio.CancelledError:\n",
    "        for task in result_task_list:\n",
    "            task.cancel()\n",
    "        for task in backup_task_list:\n",
    "            task.cancel()\n",
    "        checkpoint_task.cancel()\n",
    "        await asyncio.gather(*result_task_list, return_exceptions=True)\n",
    "        await asyncio.gather(*backup_task_list, return_exceptions=True)\n",
    "        try:\n",
    "            await checkpoint_task\n",
    "        except asyncio.CancelledError:\n",
    "            pass\n",
    "\n",
    "# Create new partitions that resume from where a previous backup job left off\n",
    "async def create_incremental_backup_partitions(client: SearchClient, timestamp_field_name: str, partitions: List[Partition], desired_partitions: int = 1, partition_size_threshold: float = 0.05, max_timestamp: Optional[str] = None) -> List[Partition]:\n",
    "    if not max_timestamp:\n",
    "        max_timestamp = await get_timestamp_bound(client, timestamp_field_name, max=True)\n",
    "    # The new partitions start from the latest timestamp that any previous partition reached\n",
    "    min_timestamp = None\n",
    "    last_id = None\n",
    "    for partition in partitions:\n",
    "        if not last_id:\n",
    "            last_id = partition.id\n",
    "        elif partition.id > last_id:\n",
    "            last_id = partition.id\n",
    "        if not min_timestamp:\n",
    "            min_timestamp = partition.last\n",
    "        elif partition.last:\n",
    "            if timestamp_to_datetime(partition.last) > timestamp_to_datetime(min_timestamp):\n",
    "                min_timestamp = partition.last\n",
    "    \n",
    "    partition_splits = await get_partition_bounds(client, timestamp_field_name, desired_partitions, partition_size_threshold, min_timestamp)\n",
    "    partitions = await get_partitions(client, timestamp_field_name, partition_splits, start_id=last_id + 1, min_timestamp=min_timestamp, max_timestamp=max_timestamp)\n",
    "    return partitions\n"
   ]
  },
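  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `upload_documents` helper retries oversized requests by splitting the batch in half. The following standalone sketch shows that depth-first splitting strategy in isolation. `PayloadTooLargeError` and `fake_upload` are hypothetical stand-ins for the search service and its error; the notebook itself catches `TypeError` and calls `SearchClient.upload_documents`.\n",
    "\n",
    "```python\n",
    "import asyncio\n",
    "from typing import List\n",
    "\n",
    "\n",
    "class PayloadTooLargeError(Exception):\n",
    "    '''Hypothetical stand-in for the error raised when a request is too large.'''\n",
    "\n",
    "\n",
    "# IDs of documents that the fake service accepted, in upload order\n",
    "uploaded: List[str] = []\n",
    "\n",
    "\n",
    "async def fake_upload(batch: List[dict], limit: int = 3) -> None:\n",
    "    # Hypothetical service call: rejects any batch with more than `limit` documents\n",
    "    if len(batch) > limit:\n",
    "        raise PayloadTooLargeError(len(batch))\n",
    "    uploaded.extend(doc['id'] for doc in batch)\n",
    "\n",
    "\n",
    "async def upload_with_splitting(documents: List[dict]) -> int:\n",
    "    # Depth-first stack of batches; an oversized batch is split in half and retried\n",
    "    stack = [documents]\n",
    "    max_success = 0\n",
    "    while stack:\n",
    "        batch = stack.pop()\n",
    "        try:\n",
    "            await fake_upload(batch)\n",
    "            max_success = max(max_success, len(batch))\n",
    "        except PayloadTooLargeError:\n",
    "            if len(batch) == 1:\n",
    "                raise  # a single document can't be split any further\n",
    "            mid = len(batch) // 2\n",
    "            # Push the second half first so the first half is retried first\n",
    "            stack.append(batch[mid:])\n",
    "            stack.append(batch[:mid])\n",
    "    return max_success\n",
    "\n",
    "\n",
    "docs = [{'id': str(i)} for i in range(10)]\n",
    "largest = asyncio.run(upload_with_splitting(docs))\n",
    "print(largest)\n",
    "```\n",
    "\n",
    "With a limit of three documents per request, the ten sample documents upload as batches of 2, 3, 2, and 3, so the function returns 3. That is the kind of value `backup_results` passes back as the next `initial_batch_size`. In a notebook cell, use `await upload_with_splitting(docs)` instead of `asyncio.run(...)`, because Jupyter already runs an event loop."
   ]
  },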
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initiate the backup\n",
    "\n",
    "* Set `desired_partitions` to a value greater than 1 to set up parallel backup jobs.\n",
    "* Change `backup_tasks` to set how many parallel backup workers upload results from the source service to the destination index. By default, `backup_index` sets it to twice `desired_partitions`.\n",
    "* Both `desired_partitions` and `backup_tasks` affect the speed of the backup.\n",
    "  * Services with more [replicas](https://learn.microsoft.com/azure/search/search-capacity-planning#concepts-search-units-replicas-partitions) or a higher [SKU](https://learn.microsoft.com/azure/search/search-sku-tier) may benefit from a higher number of parallel backup jobs and parallel backup workers.\n",
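    "* With `incremental_backup=True` (the default), `backup_index` calls `create_incremental_backup_partitions` to resume from a previous multi-partition backup job's saved state if records have been added or updated. With a single partition, it resumes from the latest timestamp in the destination index instead.\n"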
   ]
  },
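  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following sketch is a simplified model of how these two settings relate inside `backup_index_with_resume`: one producer per partition queues result pages, and `backup_tasks` workers drain the shared queue. It uses fake document IDs instead of search results. It also shuts the workers down by sending one `None` sentinel per worker after all producers finish, which is simpler than the notebook's finished-partition counter.\n",
    "\n",
    "```python\n",
    "import asyncio\n",
    "from typing import List\n",
    "\n",
    "# Hypothetical stand-ins for the notebook's settings\n",
    "desired_partitions = 3\n",
    "backup_tasks = desired_partitions * 2\n",
    "\n",
    "\n",
    "async def produce(partition_id: int, queue: asyncio.Queue) -> None:\n",
    "    # Each producer stands in for one partition's query loop and queues pages of fake IDs\n",
    "    for page in range(2):\n",
    "        await queue.put((partition_id, [f'{partition_id}-{page}-{i}' for i in range(5)]))\n",
    "\n",
    "\n",
    "async def consume(queue: asyncio.Queue, uploaded: List[str]) -> None:\n",
    "    # Each backup worker processes pages until it receives the None sentinel\n",
    "    while True:\n",
    "        item = await queue.get()\n",
    "        if item is None:\n",
    "            break\n",
    "        _, page = item\n",
    "        uploaded.extend(page)  # stand-in for uploading to the destination index\n",
    "\n",
    "\n",
    "async def main() -> List[str]:\n",
    "    queue: asyncio.Queue = asyncio.Queue()\n",
    "    uploaded: List[str] = []\n",
    "    workers = [asyncio.create_task(consume(queue, uploaded)) for _ in range(backup_tasks)]\n",
    "    await asyncio.gather(*(produce(p, queue) for p in range(desired_partitions)))\n",
    "    # All producers are done: send one sentinel per worker so every worker exits\n",
    "    for _ in workers:\n",
    "        await queue.put(None)\n",
    "    await asyncio.gather(*workers)\n",
    "    return uploaded\n",
    "\n",
    "\n",
    "uploaded = asyncio.run(main())\n",
    "print(len(uploaded))\n",
    "```\n",
    "\n",
    "Because every worker reads from the same queue, pages from one partition can finish out of order. That is why `checkpoint_results` only saves a partition update when its timestamp is at least as recent as the last one saved for that partition."
   ]
  },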
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "async def backup_index(index_name: str, desired_partitions: int = 1, incremental_backup: bool = True):\n",
    "    async with SearchClient(source_endpoint, index_name, source_credential) as source_client, \\\n",
    "               SearchClient(destination_endpoint, index_name, destination_credential) as destination_client:\n",
    "\n",
    "        # Example implementation to store backup job state\n",
    "        # Saves partition JSON to files in the \"partitions\" directory\n",
    "        backup_state_directory = os.path.join(\"partitions\", index_name)\n",
    "        backup_format = os.path.join(backup_state_directory, \"backup-{partition}.json\")\n",
    "        os.makedirs(backup_state_directory, exist_ok=True)\n",
    "        def on_backup_page(partition: Partition) -> None:\n",
    "            with open(backup_format.format(partition=partition.id), \"w\") as f:\n",
    "                json.dump(asdict(partition), f, indent=2)\n",
    "\n",
    "        # Restore partition JSON from files in the \"partitions\" directory\n",
    "        def read_backups_state(directory: str) -> List[Partition]:\n",
    "            if not os.path.isdir(directory):\n",
    "                return []\n",
    "            partitions = []\n",
    "            for file in os.listdir(directory):\n",
    "                if re.fullmatch(r'backup-\\d+\\.json', file):\n",
    "                    with open(os.path.join(directory, file), \"r\") as f:\n",
    "                        data = json.load(f)\n",
    "                        partitions.append(Partition(**data))\n",
    "\n",
    "            return partitions\n",
    "\n",
    "        partitions = None\n",
    "        incremental_partitions = None\n",
    "        if desired_partitions == 1:\n",
    "            # Resume backup from the last timestamp in the destination index\n",
    "            incremental_partitions = [\n",
    "                Partition(\n",
    "                    id=0,\n",
    "                    start=await get_timestamp_bound(destination_client, timestamp_field_name, max=True),\n",
    "                    end=await get_timestamp_bound(source_client, timestamp_field_name, max=True),\n",
    "                    last=None\n",
    "                )\n",
    "            ]\n",
    "        else:\n",
    "            partitions = read_backups_state(backup_state_directory)\n",
    "\n",
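    "        # Compute fresh partitions only when there is neither saved state nor a single-partition resume point\n",
    "        if not partitions and not incremental_partitions:\n",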
    "            partition_splits = await get_partition_bounds(source_client, timestamp_field_name, desired_partitions=desired_partitions)\n",
    "            partitions = await get_partitions(source_client, timestamp_field_name, partition_splits)\n",
    "        elif incremental_backup and not incremental_partitions and partitions:\n",
    "            incremental_partitions = await create_incremental_backup_partitions(source_client, timestamp_field_name, partitions, desired_partitions=desired_partitions)\n",
    "\n",
    "        await backup_index_with_resume(\n",
    "            source_client,\n",
    "            destination_client,\n",
    "            timestamp_field_name,\n",
    "            partitions=incremental_partitions or partitions,\n",
    "            on_backup_page=on_backup_page,\n",
    "            backup_tasks=desired_partitions * 2\n",
    "        )"
   ]
  },
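  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The backup state that `on_backup_page` writes and `read_backups_state` reads is plain JSON, one file per partition. The following sketch shows that save-and-reload round trip with a temporary directory. Its `Partition` dataclass is a hypothetical minimal version with only the `id`, `start`, `end`, and `last` fields that `backup_index` passes; the notebook's own class might define more. The file-name check uses `re.fullmatch` so that only files named exactly like `backup-0.json` are loaded.\n",
    "\n",
    "```python\n",
    "import json\n",
    "import os\n",
    "import re\n",
    "import tempfile\n",
    "from dataclasses import asdict, dataclass\n",
    "from typing import List, Optional\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class Partition:\n",
    "    # Hypothetical minimal version of the notebook's Partition class\n",
    "    id: int\n",
    "    start: Optional[str]\n",
    "    end: Optional[str]\n",
    "    last: Optional[str]\n",
    "\n",
    "\n",
    "def save_partition(directory: str, partition: Partition) -> None:\n",
    "    # One JSON file per partition, overwritten after every uploaded page\n",
    "    path = os.path.join(directory, f'backup-{partition.id}.json')\n",
    "    with open(path, 'w') as f:\n",
    "        json.dump(asdict(partition), f, indent=2)\n",
    "\n",
    "\n",
    "def load_partitions(directory: str) -> List[Partition]:\n",
    "    # fullmatch skips stray files such as backup-1.json.bak\n",
    "    partitions = []\n",
    "    for name in sorted(os.listdir(directory)):\n",
    "        if re.fullmatch(r'backup-[0-9]+[.]json', name):\n",
    "            with open(os.path.join(directory, name)) as f:\n",
    "                partitions.append(Partition(**json.load(f)))\n",
    "    return partitions\n",
    "\n",
    "\n",
    "with tempfile.TemporaryDirectory() as state_dir:\n",
    "    p = Partition(id=0, start='2024-01-01T00:00:00Z', end='2024-02-01T00:00:00Z', last=None)\n",
    "    p.last = '2024-01-15T12:00:00Z'  # pretend pages up to this timestamp were uploaded\n",
    "    save_partition(state_dir, p)\n",
    "    restored = load_partitions(state_dir)\n",
    "\n",
    "print(restored)\n",
    "```\n",
    "\n",
    "Because each file is overwritten after every uploaded page, an interrupted run can be resumed by reading the files back and continuing each partition from its saved `last` timestamp."
   ]
  },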
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for index_name in index_names:\n",
    "    await backup_index(index_name)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
